{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 图片数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "\n",
    "@dataclass\n",
    "class Preprocessing:\n",
    "    width: int\n",
    "    height: int\n",
    "    channels: int\n",
    "    mean: tuple[float] = (0,)\n",
    "    std: tuple[float] = (1,)\n",
    "    layout: str = \"HWC\"\n",
    "    name: str = \"data\"\n",
    "    format: str = \"RGB\"\n",
    "\n",
    "    def __post_init__(self):\n",
    "        if self.layout == \"HWC\":\n",
    "            self.shape = self.height, self.width, self.channels\n",
    "        elif self.layout == \"CHW\":\n",
    "            self.shape = self.channels, self.height, self.width\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown layout: {self.layout}\")\n",
    "\n",
    "    def load(self, path: str | Path) -> np.ndarray:\n",
    "        \"\"\"加载图片\"\"\"\n",
    "        img = Image.open(path).resize((self.width, self.height)) # uint8 数据\n",
    "        if self.format == \"GRAY\":\n",
    "            img = img.convert(\"L\")\n",
    "            img = np.expand_dims(img, axis=-1) # WH->HWC\n",
    "        elif self.format == \"RGB\":\n",
    "            img = np.array(img.convert(\"RGB\")) # WHC->HWC\n",
    "        elif self.format == \"BGR\":\n",
    "            img = np.array(img.convert(\"RGB\")) # WHC->HWC\n",
    "            img = img[..., ::-1] # RGB 转 BGR\n",
    "        else:\n",
    "            raise TypeError(f'暂未支持数据布局 {self.format}')\n",
    "        return img\n",
    "    \n",
    "    def __call__(self, path: str | Path) -> np.ndarray:\n",
    "        img = self.load(path)/255.0 # 归一化（将 uint8 数据归一化到 [0, 1]，这是神经网络的标准输入格式）\n",
    "        img = (img - self.mean) / self.std # 标准化，使数据分布更接近标准正态分布\n",
    "        img = img.astype(\"float32\")\n",
    "        if self.layout == \"CHW\":\n",
    "            img = img.transpose(2, 0, 1) # HWC->CHW\n",
    "        return img\n",
    "\n",
    "    def torch_call(self, path: str | Path) -> \"torch.Tensor\":\n",
    "        assert self.layout == \"CHW\", \"torchvision 只支持 CHW 布局\"\n",
    "        from torchvision.transforms import v2\n",
    "        import torch\n",
    "        from torch import nn\n",
    "        inp = self.load(path)\n",
    "        return nn.Sequential(\n",
    "            v2.ToImage(),\n",
    "            v2.ToDtype(torch.float32, scale=True),\n",
    "            v2.Normalize(self.mean, self.std)\n",
    "        )(inp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from PIL import Image\n",
    "root_dir = Path('../../images')\n",
    "mean = [0.485, 0.456, 0.406]\n",
    "std = [0.229, 0.224, 0.225]\n",
    "layout = \"CHW\"\n",
    "preprocessing = Preprocessing(32, 32, 3, mean, std, layout)\n",
    "# torch_inp = preprocessing.torch_call(root_dir/\"Giant_Panda_in_Beijing_Zoo_1.jpg\")\n",
    "inp = preprocessing(root_dir/\"Giant_Panda_in_Beijing_Zoo_1.jpg\")\n",
    "# np.testing.assert_almost_equal(inp, torch_inp.numpy(), decimal=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 32, 32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inp.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
